//
//  entitleMe.swift
//  Taurine
//
//  Created by CoolStar on 3/1/21.
//

import Foundation

/// Returns the top-level entitlement keys from the current process's own signed entitlement blob.
///
/// The blob is fetched with `csops(CS_OPS_ENTITLEMENT_BLOB)` in two calls:
/// 1. With an 8-byte buffer. The blob length is read from the second header word, which is big-endian.
/// 2. With a buffer of that length, to fetch the whole blob.
///
/// The bytes after the 8-byte header are parsed as a property list.
///
/// - Returns: The keys of the property-list dictionary. Returns `[]` if the blob is empty,
///   cannot be read, or does not parse as a `[String: Any]` dictionary.
func getSafeEntitlements() -> [String] {
    let CS_OPS_ENTITLEMENT_BLOB = UInt32(7)
    let headerSize = 2 * MemoryLayout<UInt32>.size
    
    // First call: only the header written into `hdr` is used. Its return value is not checked.
    // If nothing is written, `hdr` stays zeroed and the length guard below bails out.
    var hdr = [UInt32](repeating: 0, count: 2)
    csops(getpid(), CS_OPS_ENTITLEMENT_BLOB, &hdr, UInt32(headerSize))
    
    let len = UInt32(bigEndian: hdr[1])
    // Without a payload past the header there is nothing to parse, and slicing
    // `[headerSize..<len]` would trap.
    guard Int(len) > headerSize else {
        return []
    }
    
    // Use UInt32-sized storage that covers `len` bytes. The original rebound a byte buffer to
    // UInt32 with an element capacity of `len`, which overstated the buffer size by 4x.
    let wordSize = MemoryLayout<UInt32>.size
    var words = [UInt32](repeating: 0, count: (Int(len) + wordSize - 1) / wordSize)
    guard csops(getpid(), CS_OPS_ENTITLEMENT_BLOB, &words, len) == 0 else {
        return []
    }
    
    let entitlements = words.withUnsafeBytes { Data($0[headerSize..<Int(len)]) }
    guard let plist = try? PropertyListSerialization.propertyList(from: entitlements, options: [], format: nil),
          let dict = plist as? [String: Any] else {
        return []
    }
    return Array(dict.keys)
}

/// Temporarily lends this process entitlements taken from another binary.
///
/// `grabEntitlements(path:wantedEntitlements:)` works as follows:
/// - It spawns `path` suspended.
/// - It overwrites key/value slots in our own kernel entitlement dictionary with the
///   key/value pointers of the wanted entries from the child's dictionary.
/// - It records the original slot contents in `backup`, so `resetEntitlements()` can write
///   them back and kill the child.
///
/// In DEBUG builds, constructing a second instance traps.
class EntitleMe {
    #if DEBUG
    private static var entitleMeInstances = 0
    #endif
    
    /// Destination of the ctbypass helper written by `extractCTBypass()`. `resetEntitlements()` removes it.
    private static let ctBypassPath = "/var/containers/Bundle/Application/ctbypass"
    
    let offsets = Offsets.shared
    private var electra: Electra
    
    /// Kernel address of each overwritten dictionary slot -> the 64-bit value it held originally.
    private var backup: [UInt64:UInt64] = [:]
    
    /// pid of the suspended donor process spawned by `grabEntitlements`. Set to 0 when there is none.
    private var target_pid: pid_t = 0
    /// True while our entitlement slots are patched and must be restored by `resetEntitlements()`.
    private var has_entitlements = false
    
    init(electra: Electra) {
        #if DEBUG
        EntitleMe.entitleMeInstances += 1
        
        guard (EntitleMe.entitleMeInstances <= 1) else {
            fatalError("Electra API Misuse!!!")
        }
        #endif
        
        self.electra = electra
    }
    
    /// Unpacks the zstd-compressed `ctbypass` resource to `ctBypassPath`.
    ///
    /// Any existing file is removed first. The new file gets uid/gid 0 and mode 0755.
    /// - Returns: `false` if extraction fails, otherwise `true`. `chown`/`chmod` results are not checked.
    public func extractCTBypass() -> Bool {
        let dest = EntitleMe.ctBypassPath
        unlink(dest)
        guard extractZstd(source: "ctbypass", dest: dest) else {
            print("failed to extract ctbypass")
            return false
        }
        chown(dest, 0, 0)
        chmod(dest, 0755)
        
        return true
    }
    
    /// Spawns `path` suspended, then swaps the wanted entitlements into our own entitlement dictionary.
    ///
    /// A slot in our dictionary can serve as a donor only if both of these hold:
    /// - its key is not itself wanted;
    /// - its key also appears in our signed entitlement blob (`getSafeEntitlements()`).
    ///
    /// Each wanted key found in the child's dictionary is written over one donor slot.
    ///
    /// - Parameters:
    ///   - path: Executable to spawn, suspended, as the entitlement donor. Its argv is `[path]`.
    ///   - wantedEntitlements: Entitlement keys to borrow. Keys we already hold are dropped from the list.
    /// - Returns: `true` if every still-wanted entitlement was found and installed.
    ///   Returns `false` in these cases:
    ///   - entitlements are already grabbed;
    ///   - there are too few donor slots;
    ///   - spawning fails;
    ///   - the child's proc cannot be found;
    ///   - some wanted key was missing from the child.
    ///
    ///   In the last case slots may already be patched, so `resetEntitlements()` must still be called.
    public func grabEntitlements(path: String, wantedEntitlements: [String]) -> Bool {
        guard !has_entitlements else {
            return false
        }
        
        backup = [:]
        
        let safeEntitlements = getSafeEntitlements()
        
        var wantedEntitlements = wantedEntitlements
        
        // Step 1: snapshot our own entitlement dictionary and pick donor slots.
        let our_proc = electra.our_proc
        
        let our_ucred = rk64ptr(our_proc + offsets.proc.ucred)
        // NOTE(review): the entitlements dictionary is read from cr_label + 0x8 — presumably the
        // AMFI MAC-label slot; confirm against the offsets table.
        let our_entitlements = rk64ptr(rk64ptr(our_ucred + offsets.ucred.cr_label) + 0x8)
        
        let our_entitlementcount = rk32(our_entitlements + Offsets.shared.osobject.os_dict_count)
        
        let our_entriessize = Int(our_entitlementcount) * MemoryLayout<dict_entry_t>.size
        let our_os_dict_entries = malloc(our_entriessize)
        // Freed on every exit path. Previously the early `backup.count` return leaked this buffer.
        defer { free(our_os_dict_entries) }
        
        let our_entry_ptr = rk64ptr(our_entitlements + Offsets.shared.osobject.os_dict_dict_entry)
        kread(our_entry_ptr, our_os_dict_entries, our_entriessize)
        
        // Each element is [keySlotAddr, valueSlotAddr] inside our kernel dictionary.
        var keyValAddrs: [[UInt64]] = []
        
        var idx = 0
        iterate_keys_in_dict(our_os_dict_entries?.assumingMemoryBound(to: dict_entry_t.self), our_entitlementcount) { (key, val) in
            defer { idx += 1 }
            
            let keystr = self.readKeyString(key)
            if wantedEntitlements.contains(keystr) {
                wantedEntitlements.removeAll { $0 == keystr }
                print("We already have this entitlement. Skipping", keystr)
                return
            }
            
            if !safeEntitlements.contains(keystr) {
                print("Entitlement is not safe. Skipping", keystr)
                return
            }
            
            // Slots are assumed to be 0x10 bytes: key pointer at +0, value pointer at +0x8.
            let entitlement_keyaddr = our_entry_ptr + UInt64(0x10 * idx)
            let entitlement_valaddr = entitlement_keyaddr + 0x8
            
            // The live kernel slot must still match the snapshot we iterated.
            if (rk64(entitlement_keyaddr) != key) || (rk64(entitlement_valaddr) != val) {
                print("This shouldn't happen???")
                return
            }
            
            self.backup[entitlement_keyaddr] = key
            self.backup[entitlement_valaddr] = val
            
            keyValAddrs.append([entitlement_keyaddr, entitlement_valaddr])
        }
        
        // Each donor slot contributes two backup entries (key + value).
        guard backup.count >= wantedEntitlements.count * 2 else {
            return false
        }
        
        // Step 2: Hijack the entitlement -- again ;))
        guard let pid = spawnSuspended(path: path) else {
            return false
        }
        
        let target_proc = electra.find_proc(pid: UInt32(pid))
        guard target_proc != 0 else {
            // Don't leave a suspended child behind when it can't be used.
            kill(pid, SIGKILL)
            return false
        }
        
        self.target_pid = pid
        
        let target_ucred = rk64ptr(target_proc + offsets.proc.ucred)
        let target_entitlements = rk64ptr(rk64ptr(target_ucred + offsets.ucred.cr_label) + 8)
        
        let target_entitlementcount = rk32(target_entitlements + Offsets.shared.osobject.os_dict_count)
        
        let target_entriessize = Int(target_entitlementcount) * MemoryLayout<dict_entry_t>.size
        let target_os_dict_entries = malloc(target_entriessize)
        defer { free(target_os_dict_entries) }
        
        let target_entry_ptr = rk64ptr(target_entitlements + Offsets.shared.osobject.os_dict_dict_entry)
        kread(target_entry_ptr, target_os_dict_entries, target_entriessize)
        
        // Step 3: point our donor slots at the child's wanted key/value objects.
        iterate_keys_in_dict(target_os_dict_entries?.assumingMemoryBound(to: dict_entry_t.self), target_entitlementcount) { (key, val) in
            let keystr = self.readKeyString(key)
            guard wantedEntitlements.contains(keystr) else {
                return
            }
            guard !keyValAddrs.isEmpty else {
                print("No donor slot left for entitlement. Skipping", keystr)
                return
            }
            print("Found an entitlement!", keystr)
            
            let keyValAddr = keyValAddrs.removeFirst()
            
            wk64(keyValAddr[0], key)
            wk64(keyValAddr[1], val)
            
            wantedEntitlements.removeAll { $0 == keystr }
        }
        
        has_entitlements = true
        
        return wantedEntitlements.isEmpty
    }
    
    /// Writes every saved slot in `backup` back to the kernel, kills the donor process, and removes ctbypass.
    ///
    /// Does nothing unless `grabEntitlements` has patched our dictionary.
    public func resetEntitlements(){
        guard has_entitlements else {
            return
        }
        
        for (addr, val) in backup {
            wk64(addr, val)
        }
        
        backup = [:]
        
        // Restore first, then kill the donor. Presumably this is so our dictionary never points at
        // objects owned by an exited process.
        // kill(0, …) would signal our whole process group, so only kill a real pid.
        if target_pid > 0 {
            kill(target_pid, SIGKILL)
        }
        target_pid = 0
        
        unlink(EntitleMe.ctBypassPath)
        
        has_entitlements = false
    }
    
    /// Reads the characters of the kernel OSString at `key`.
    ///
    /// The length is taken from the bits above bit 14 of the `os_string_len` field. The `- 1` drops
    /// what is presumably the stored NUL terminator (TODO confirm). A zeroed extra byte keeps `buf`
    /// NUL-terminated for `String(cString:)`.
    private func readKeyString(_ key: UInt64) -> String {
        let key_len = (rk32(key + offsets.osobject.os_string_len) >> 0xe) - 1
        let key_str = rk64ptr(key + offsets.osobject.os_string_string)
        
        var buf = [UInt8](repeating: 0, count: Int(key_len + 1))
        kread(key_str, &buf, Int(key_len))
        
        return String(cString: buf)
    }
    
    /// Spawns `path` with argv `[path]` and the inherited environment, in a suspended state.
    ///
    /// - Returns: The child's pid, or `nil` if `posix_spawn` fails. `posix_spawn` reports failure
    ///   by returning an errno value (non-zero), not -1.
    private func spawnSuspended(path: String) -> pid_t? {
        var attrp: posix_spawnattr_t?
        posix_spawnattr_init(&attrp)
        defer { posix_spawnattr_destroy(&attrp) }
        posix_spawnattr_setflags(&attrp, Int16(POSIX_SPAWN_START_SUSPENDED))
        
        let args = [path]
        let argv: [UnsafeMutablePointer<CChar>?] = args.map { $0.withCString(strdup) }
        defer { for case let arg? in argv { free(arg) } }
        
        var pid: pid_t = 0
        let retVal = posix_spawn(&pid, path, nil, &attrp, argv + [nil], environ)
        guard retVal == 0 else {
            print("posix_spawn failed:", String(cString: strerror(retVal)))
            return nil
        }
        return pid
    }
}
